%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
% ks_het_firms_baseline.m
% Solves a GE heterogeneous firms investment model with macro shocks and real and financial adjustment costs 
% via the Krusell Smith algorithm
% Ivan Alfaro, Nick Bloom and Xiaoji Lin 
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%

close all; clearvars; clc;

% Optionally record everything printed to the command window in a log file.
diary_record = 1;
if (diary_record)
    % Detach any diary left open by a previous (e.g. aborted) run first, so the
    % old log file is not locked while we try to delete it.
    diary off;
    % Start a fresh log. Only delete when the file exists, which avoids the
    % "file not found" warning on a clean first run.
    if exist('ks_het_firms.txt','file') == 2
        delete 'ks_het_firms.txt';
    end
    diary 'ks_het_firms.txt';
    diary on
end
%%%%initial program setup
%start tracking timing

tic; 
disp('%%%%%%%%%%%%%%%%%%%')
disp('Solving the het firms model with macro shocks via Krusell Smith')
disp(' ')

%set parameters
% --- Household / prices ----------------------------------------------
beta    = 1/1.0125;   % household subjective discount factor
R       = 1/beta;     % gross real interest rate implied in GE
rf      = R-1;        % constant (net) risk-free rate
% --- Firm technology and frictions -----------------------------------
delta   = 0.05;       % depreciation rate
alpha   = 0.70;       % capital share (curvature of z*k^alpha)
c_k     = 0.03;       % adjustment cost: fraction of output lost whenever investment is nonzero
kappa   = 0.97;       % wedge between risk-free rate and return on cash: cash earns (1+rf)*kappa

%set up solution
% --- Forecast-rule (Krusell-Smith) iteration --------------------------
maxfcstit    = 50;     % max number of iterations on the fcst rule
maxfcsterr   = 1e-2;   % max error in the fcst rule coefficients
fcstdamp     = 0.01;   % fcst rule updates 100*fcstdamp percent of the way
fcsterrortol = 0.01;   % convergence tolerance for the forecast rules
% --- Value function iteration -----------------------------------------
maxvfit      = 30;     % max number of VF iterations
maxvferr     = 1e-3;   % max abs change in the value function at convergence
howardnum    = 30;     % Howard (policy-evaluation) steps within each VF iteration
% --- Distribution -----------------------------------------------------
maxdistit    = 1000;   % max number of distributional iterations
maxdisterr   = 1e-7;   % max abs change of the stationary distribution
maxbunch     = 1e-3;   % max bunching at grid endpoints
% --- Market clearing --------------------------------------------------
maxpit       = 50;     % max iterations to find the equilibrium price
maxperr      = 1e-4;   % max error in the clearing price for simulations
% --- Tolerances on changes in fit statistics across fcst iterations ---
RMSEchangetol       = 0.001;   % RMSE changes
R2changetol         = 0.01;    % R^2 changes
maxDenHaanchangetol = 0.01;    % max Den Haan statistic
avgDenHaanchangetol = 0.01;    % avg Den Haan statistic
% GE convergence criterion selector:
%   1 = coefficients converged
%   2 = RMSE converged (in changes)
%   3 = R^2 converged (in changes)
%   4 = max Den Haan statistic converged (in changes)
%   5 = avg Den Haan statistic converged (in changes)
GEerrorswitch = 4;

checkbounds = 1;   % check whether simulation hits the boundary

% Price grid: pnum equally spaced points on [plb, pub]. The same bounds start
% the bracketing search for the market-clearing price and clamp the price forecast.
plb  = 0.7;
pub  = 1.8;
pnum = 15;
p0   = linspace(plb,pub,pnum);

% Controls for the price-clearing process
perrortol = 1e-4;
disttol   = 1e-4;   % tolerance for ignoring this point in the dist
pwindow   = 0.1;    % restarted bracket is the forecast price +/- 4*pwindow
pcutoff   = 15;     % price iteration at which the bracketing search is restarted

%%%%%%%%%%%%
%%%%This block sets up the grids and productivity transition matrices
%%%%%%%%%%%%
%set grid dimensions for states (z,sigmaz,k,n,A,sigmaA,K,F)

% Micro capital grid: log-spaced from kmin with a constant ratio of
% 1/(1-delta) between neighbours. Letting capital depreciate without
% investing therefore moves a firm exactly one grid point down (up to
% rounding, see ZeroBoundary).
knum = 200;
kmin = .2;
kmax = exp(log(kmin) - (knum-1) * log(1.0-delta));
k0   = exp(linspace(log(kmin),log(kmax),knum))';   % column vector
knum = numel(k0);

% Micro cash grid: a zero-cash point followed by nnum log-spaced points on
% [nmin, nmax]. After this, nnum counts all points and nmin is 0.
nmin = .1;
nmax = 20;
nnum = 5;
n0   = [0, exp(linspace(log(nmin),log(nmax),nnum))];   % row vector
nnum = numel(n0);
nmin = n0(1);

% Macro (aggregate) capital grid: Knum log-spaced points on [Kmin, Kmax].
Knum = 10;
Kmin = 100;
Kmax = 450;
K0   = exp(linspace(log(Kmin),log(Kmax),Knum))';

%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
% firm time-varying uncertainty
snum        = 2;                                  % number of firm specific uncertainty states
sigmaz_l    = 0.051;                              % low firm-specific sigma from Bloom et al 2018 RUBC
sigmaz      = [sigmaz_l sigmaz_l*4.1];            % low and high firm-specific sigma from Bloom et al 2018 RUBC
% aggregate time-varying uncertainty
Snum        = 2;                                  % number of aggregate uncertainty states
sigmaA_l    = 0.0067;                             % low aggregate sigma from Bloom et al 2018 RUBC
sigmaA      = [sigmaA_l sigmaA_l*1.6];            % low and high aggregate sigma from Bloom et al 2018 RUBC
% Two-state Markov matrix for aggregate uncertainty (Bloom et al RUBC 2018):
% row = current state (low, high), column = next state.
PiSigma     = [1-0.026 0.026;
               1-0.943 0.943];
% Initial state of the AGGREGATE uncertainty chain used to start the macro
% simulation. It indexes PiSigma/sigmaA, so it must use Snum. The original
% used the firm-level snum, which only worked because snum == Snum.
sigmaAinit  = ceil(Snum/2);
Pisigma     = PiSigma;                            % transition matrix from Bloom et al RUBC 2018 same for firm-specific productivity

%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
% firm-specific productivity
% Discretized on znum points by TransitProb (defined outside this file).
% Both firm uncertainty levels sigmaz are passed in, and Piz is later indexed
% as Piz(z,z',s). stdz is the unconditional s.d. of an AR(1) with persistence
% rhoz and innovation s.d. sigmaz_l.
znum    = 5;       % number of grid points
zbar    = 0;       % mean level (presumably of log z, since the grid is exponentiated below)
rhoz    = 0.95;    % persistence
stdz    = sqrt(sigmaz_l^2/(1-rhoz^2));
nsigmaz = 5;       % controls the boundary of z
Jen     = 2;       % parameter on Jensen adjustment
[z0, Piz] = TransitProb(rhoz,zbar,stdz,sigmaz,znum,snum,nsigmaz,Jen);
% convert the TransitProb grid to levels and transpose it
z0      = exp(z0)';
zmin    = min(z0);
zmax    = max(z0);
intit   = ceil(znum/2);   % index of the median z

%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
%set up macro productivity grid
% Aggregate productivity A, discretized on Anum points by TransitProb
% (defined outside this file). Both aggregate uncertainty levels sigmaA are
% passed in, and PiA is later indexed as PiA(A,A',S).
Anum    = 5;       % number of grid points
Abar    = -1;      % mean level (presumably of log A, since the grid is exponentiated below)
rhoA    = 0.95;    % persistence
stdA    = sqrt(sigmaA_l^2/(1-rhoA^2));
nsigmaA = 1.75;    % controls the boundary of A
[A0, PiA] = TransitProb(rhoA,Abar,stdA,sigmaA,Anum,Snum,nsigmaA,Jen);
A0      = exp(A0)';
Amin    = min(A0);
Amax    = max(A0);

%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
% financial shock
% Two-state Markov chain for the external financing cost. FinC(F) is the
% proportional cost charged on negative equity payouts in the return matrix.
Fnum  = 2;                       % number of states
Fin_l = 0.03;                    % low financing cost
Fin_h = 0.06;                    % high financing cost
FinC  = [Fin_l; Fin_h];          % column of financing costs
PiF   = [0.95 0.05; 0.5 0.5];    % row = current state, column = next state
%PiF  = PiSigma;                 % alternative transition matrix of financial shock
imF   = ceil(Fnum/2);

%set up simulation
Tsim = 5000;          % periods kept after burn-in
Terg = 1000;          % burn-in periods discarded to remove the influence of initialization
Ttot = Terg + Tsim;   % total number of simulated periods

% Starting grid indices for the simulation.
% NOTE(review): floor(x/2) gives 0 (an invalid index) if a dimension has a single point.
kinit = floor(knum/2);   % micro capital
ninit = ceil(nnum/2);    % micro cash
zinit = floor(znum/2);   % micro productivity
sinit = floor(snum/2);   % micro sigma
Kinit = floor(Knum/2);   % macro capital
Ainit = floor(Anum/2);   % macro productivity
Sinit = floor(Snum/2);   % macro sigma
Finit = floor(Fnum/2);   % financial shock

% Random numbers: randseed is kept for reference, but the generator is reset
% to MATLAB's default settings, which also makes runs reproducible.
randseed = 2501;
%rng(randseed);
rng('default');

% Plotting parameters
lwidnum  = 2;    % line width for graphs
fsizenum = 12;   % font size for graphs

%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
%do the exogenous macro simulation outside of the forecast rule loop
% The aggregate states (A, sigmaA, F) are exogenous Markov chains. One path of
% state indices is drawn here and reused in every forecast-rule iteration.
% Each chain is simulated by inverse-CDF sampling: the next state is the first
% column whose cumulative transition probability reaches the uniform draw.
Asimpos         = zeros(Ttot,1);        %macro productivity
Asimpos(1)      = Ainit;                %macro productivity
PisumA          = cumsum(PiA,2);        %macro productivity
Ashocks         = rand(Ttot,1);         %macro productivity
sigmaAshocks    = rand(Ttot,1);         %macro uncertainty shock
PisumSigma      = cumsum(PiSigma,2);    %macro uncertainty shock
sigmaAsimpos    = zeros(Ttot,1);        %macro uncertainty shock
sigmaAsimpos(1) = sigmaAinit;           %macro uncertainty shock
Fsimpos         = zeros(Ttot,1);        %macro financial shock
Fsimpos(1)      = Finit;                %macro financial shock
PisumF          = cumsum(PiF,2);        %macro financial shock
Fshocks         = rand(Ttot,1);         %macro financial shock

% Inverse-CDF draw. For a nondecreasing cumulative row this equals
% find(cdfrow >= u, 1). If rounding leaves the last cumulative probability
% slightly below u, find would return [] and the assignment below would
% error. Here the draw falls back to the last state instead.
drawstate = @(cdfrow,u) min(sum(cdfrow < u) + 1, numel(cdfrow));

%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
for t = 2 : Ttot

    % time-varying macro uncertainty
    Sct             = sigmaAsimpos(t-1);
    sprimect        = drawstate(PisumSigma(Sct,:), sigmaAshocks(t));
    sigmaAsimpos(t) = sprimect;                                     %store simulated value

    % macro productivity: the t-1 -> t transition uses last period's uncertainty state Sct
    Act        = Asimpos(t-1);
    Aprimect   = drawstate(PisumA(Act,:,Sct), Ashocks(t));
    Asimpos(t) = Aprimect;                                          %store simulated value

    % financial shock
    Fct        = Fsimpos(t-1);
    Fprimect   = drawstate(PisumF(Fct,:), Fshocks(t));
    Fsimpos(t) = Fprimect;                                          %store simulated value
end

sigmaAsim = sigmaA(sigmaAsimpos);                                   % simulated macro uncertainty (sigma) levels
ZeroBoundary = 1.E-4;                                               % bound for fixing computer error when computing investment on grid

% Full-state broadcast grids on (k',n',z,s,k,n,p) = (next capital, next cash,
% firm productivity, firm uncertainty, current capital, current cash,
% candidate price). Each array repeats one grid vector along the dimension
% marked in its trailing comment. This lets the simulation evaluate output,
% investment and payouts for every state, choice and candidate price in one
% vectorized expression.
% NOTE(review): each 7-D array has knum^2*nnum^2*znum*snum*pnum elements
% (216e6 doubles, about 1.7 GB each, with the grid sizes set above).
% Memory use is dominated by these arrays.
z1         = permute(repmat(z0,[1 snum knum nnum knum nnum   pnum]),[5 6 1 2 3 4 7]); %(k',n',z,s,k,n,p) varies with z
k1         = permute(repmat(k0,[1 znum snum nnum knum nnum   pnum]),[5 6 2 3 1 4 7]); %(k',n',z,s,k,n,p) varies with current k
kprime1    = permute(repmat(k0,[1 znum snum knum nnum nnum   pnum]),[1 6 2 3 4 5 7]); %(k',n',z,s,k,n,p) varies with k'
n1         = permute(repmat(n0',[1 znum snum knum knum nnum  pnum]),[5 6 2 3 4 1 7]); %(k',n',z,s,k,n,p) varies with current n
nprime1    = permute(repmat(n0',[1 znum snum knum nnum knum  pnum]),[6 1 2 3 4 5 7]); %(k',n',z,s,k,n,p) varies with n'
p1         = permute(repmat(p0',[1 znum snum knum nnum]),[2 3 4 5 1]); %(z,s,k,n,p) varies with p; NOTE(review): not referenced in the visible part of the script
pmat       = permute(repmat(p0',[1 knum,nnum,znum,snum,knum,nnum]),[2 3 4 5 6 7 1]);%(k',n',z,s,k,n,p) varies with p
k1_z_alpha = z1.*k1.^alpha;                 %precomputed output z*k^alpha; scaled by aggregate A during the simulation

%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
% pre-determine matrix
% State-space sizes. Aggregate states are (A,S,F,K). Exogenous states add the
% firm states (z,s). The full state adds the endogenous firm states (k,n).
Aggnum      = Anum*Snum*Fnum*Knum;
exostatenum = Anum*Snum*Fnum*Knum*snum*znum;
statenum    = Anum*Snum*Fnum*Knum*snum*znum*knum*nnum;
ASKnum      = Anum*Fnum*Knum;

% Next-period transition probabilities for every aggregate state (A,S,F,K).
% Entry (x',A,S,F,K) is Prob(x' | current state). Only one current index
% matters for each chain, so each array is the base matrix replicated along
% the irrelevant dimensions.
PiFmat0      = repmat(reshape(PiF.',[Fnum 1 1 Fnum]),[1 Anum Snum 1 Knum]);     % (F',A,S,F,K) = PiF(F,F')
PiAmat0      = repmat(permute(PiA,[2 1 3]),[1 1 1 Fnum Knum]);                   % (A',A,S,F,K) = PiA(A,A',S)
PiSigmaAmat0 = repmat(reshape(PiSigma.',[Snum 1 Snum]),[1 Anum 1 Fnum Knum]);   % (S',A,S,F,K) = PiSigma(S,S')

% Flatten: row i = aggregate state i (A fastest), columns = next state.
PiFmat      = reshape(PiFmat0,[Fnum Aggnum]).';
PiAmat      = reshape(PiAmat0,[Anum Aggnum]).';
PiSigmaAmat = reshape(PiSigmaAmat0,[Snum Aggnum]).';

%%%%
% Transition probabilities expanded over every exogenous state (A,S,F,K,z,s).
% Row i of each *mat1 matrix holds next-period probabilities for exogenous
% state i (A fastest). These rows are read per exogenous state inside the
% parallel (parfor) Howard step.
PiFmat01      = zeros(Fnum,Anum,Snum,Fnum,Knum,znum,snum);   % (F',A,S,F,K,z,s)
PiAmat01      = zeros(Anum,Anum,Snum,Fnum,Knum,znum,snum);   % (A',A,S,F,K,z,s)
PiSigmaAmat01 = zeros(Snum,Anum,Snum,Fnum,Knum,znum,snum);   % (S',A,S,F,K,z,s)
Pizmat01      = zeros(znum,Anum,Snum,Fnum,Knum,znum,snum);   % (z',A,S,F,K,z,s)
% The leading dimension is the FIRM uncertainty state s' (snum). It is filled
% with rows of Pisigma and reshaped with snum below. The original used the
% aggregate Snum here, which breaks as soon as snum ~= Snum.
Pisigmazmat01 = zeros(snum,Anum,Snum,Fnum,Knum,znum,snum);   % (s',A,S,F,K,z,s)

for Act = 1:Anum
    for Sct = 1:Snum
        for Fct = 1:Fnum
            for Kct = 1:Knum
                for zct = 1:znum
                    for sct = 1:snum
                        PiFmat01(:,Act,Sct,Fct,Kct,zct,sct)      = PiF(Fct,:)';
                        PiAmat01(:,Act,Sct,Fct,Kct,zct,sct)      = PiA(Act,:,Sct)';
                        PiSigmaAmat01(:,Act,Sct,Fct,Kct,zct,sct) = PiSigma(Sct,:)';
                        Pizmat01(:,Act,Sct,Fct,Kct,zct,sct)      = Piz(zct,:,sct)';
                        Pisigmazmat01(:,Act,Sct,Fct,Kct,zct,sct) = Pisigma(sct,:)';
                    end
                end
            end
        end
    end
end

% Flatten to (exogenous state, next state).
PiFmat1 = transpose(reshape(PiFmat01,[Fnum exostatenum]));
PiAmat1 = transpose(reshape(PiAmat01,[Anum exostatenum]));
PiSigmaAmat1 = transpose(reshape(PiSigmaAmat01,[Snum exostatenum]));
Pizmat1 = transpose(reshape(Pizmat01,[znum exostatenum]));
Pisigmazmat1 = transpose(reshape(Pisigmazmat01,[snum exostatenum]));

%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
% Initial policy guesses, one entry per full state (A,S,F,K,z,s,k,n) stacked
% into a single column. Every state starts at capital index kinit and cash
% index ninit.
kprimeindold = repmat(kinit, statenum, 1);
kprimeind    = kprimeindold;
kprimeold    = k0(kprimeindold);   % capital levels (k0 is a column)
kprime       = kprimeold;

nprimeindold = repmat(ninit, statenum, 1);
nprimeind    = nprimeindold;
nprimeold    = n0(nprimeindold);   % cash levels (orientation follows n0, a row)
nprime       = nprimeold;

%set up matrix to calculate return matrix
% 4-D grids over (k',n',k,n) = (next capital, next cash, capital, cash).
[kprimemat, nprimemat, kmat, nmat] = ndgrid(k0, n0, k0, n0);

% Investment and cash accumulation for every (k',n',k,n). Investment within
% ZeroBoundary of zero is set to exactly 0, so rounding in the capital grid
% does not turn "no adjustment" into a tiny nonzero investment.
imattemp = kprimemat - (1-delta)*kmat;
imattemp(abs(imattemp)<=ZeroBoundary) = 0;
imat     = imattemp;
hmat     = nprimemat - (1+rf)*kappa*nmat;

% Same quantities on the 7-D simulation grid (k',n',z,s,k,n,p).
ivaltemp = kprime1 - (1-delta)*k1;
ivaltemp(abs(ivaltemp)<=ZeroBoundary) = 0;
ival     = ivaltemp;
hval     = nprime1 - (1+rf)*kappa*n1;

% Exogenous states (A,S,F,K,z,s) stacked into one index with A fastest; the
% level of z, A and the financing cost for each stacked state.
Exognum = Anum*Snum*Fnum*Knum*znum*snum;
Rmat    = zeros(knum,nnum,knum,nnum,Exognum);   % (k',n',k,n,A*S*F*K*z*s)
zmat    = reshape(permute(repmat(z0,[1 Anum Snum Fnum Knum snum]),[2 3 4 5 1 6]),[Exognum 1]);
Amat    = reshape(repmat(A0,[1 Snum Fnum Knum znum snum]),[Exognum 1]);
FinCmat = reshape(permute(repmat(FinC,[1 Anum Snum Knum znum snum]),[2 3 1 4 5 6]),[Exognum 1]);
   

% report the time spent on model setup
disp('Finished setting up the model.'); toc; disp(' ')

%%%%%%%%%%%%
%%%%This block conducts the fcst rule iteration
%%%%%%%%%%%%
disp('Starting fcst rule iteration.'); toc;
%load Theta.mat
% Initial forecast rules (hard-coded starting coefficients, presumably from an earlier run):
%   log K' = thetaK(A,Sigma,F,1) + thetaK(A,Sigma,F,2)*log K
%   log p  = thetap(A,Sigma,F,1) + thetap(A,Sigma,F,2)*log K
% Each coefficient is one (A x Sigma) table per financial state F, and the
% tables are stacked along the third dimension.
thetaKold = zeros(Anum,Snum,Fnum,2);
thetaKold(:,:,:,1) = cat(3, ...
    [0.0980 0.0563; 0.0314 0.0196; 0.0394 0.0274; 0.0138 0.0551; 0.0115 0.0242], ...   F = 1 intercepts
    [0.0579 0.1674; 0.0984 0.1417; 0.0675 0.1479; 0.0274 0.1650; 0.0886 0.1118]);      % F = 2 intercepts
thetaKold(:,:,:,2) = cat(3, ...
    [0.9823 0.9899; 0.9945 0.9964; 0.9931 0.9950; 0.9977 0.9904; 0.9982 0.9961], ...   F = 1 slopes
    [0.9898 0.9699; 0.9823 0.9744; 0.9876 0.9733; 0.9947 0.9709; 0.9843 0.9801]);      % F = 2 slopes

thetapold = zeros(Anum,Snum,Fnum,2);
thetapold(:,:,:,1) = cat(3, ...
    [3.0109 3.4881; 4.0224 3.7377; 4.0910 3.7880; 4.1301 3.7663; 4.0527 3.8205], ...   F = 1 intercepts
    [3.1518 3.4240; 3.7172 3.1676; 3.9160 3.2028; 3.8462 3.1063; 3.8223 3.0381]);      % F = 2 intercepts
thetapold(:,:,:,2) = cat(3, ...
    [-0.5099 -0.6037; -0.6957 -0.6471; -0.7078 -0.6554; -0.7139 -0.6511; -0.6993 -0.6599], ...   F = 1 slopes
    [-0.5411 -0.5946; -0.6438 -0.5492; -0.6785 -0.5548; -0.6655 -0.5381; -0.6610 -0.5255]);      % F = 2 slopes

% Firm productivity transitions laid out as (z, z'*s). Pizz1 holds the first
% znum columns (s = 1); Pizz2 holds the remaining columns.
Pizz  = Piz(:,:);
Pizz1 = Pizz(:, 1:znum);
Pizz2 = Pizz(:, znum+1:end);

% Storage for forecast-rule fit diagnostics across fcst iterations.
[pRMSEstore, KRMSEstore, pR2store, KR2store] = deal(zeros(Anum,Snum,Fnum,maxfcstit));
[KRMSEchangestore, pRMSEchangestore, KR2changestore, pR2changestore, ...
 KmaxDenHaanstat, pmaxDenHaanstat, KavgDenHaanstat, pavgDenHaanstat, ...
 KmaxDenHaanchangestore, pmaxDenHaanchangestore, ...
 KavgDenHaanchangestore, pavgDenHaanchangestore] = deal(zeros(maxfcstit,1));

%%
for fcstit=1:maxfcstit

    disp(' ')
    disp(['Doing fcst rule iteration ' num2str(fcstit) '.'])
    disp(' ')

    %initialize the VF and policies
    if (fcstit==1)  
        %in this case, initialize using a stupid guess
        Vold          = ones(Anum,Snum,Fnum,Knum,znum,snum,knum,nnum);%(A,S,F,K,z,s,k,n)
        V             = Vold;
        V_temp        = V;
        kprimeindold = kinit*ones(statenum,1);    %(A*S*F*K*z*s*k*n,1) initialize the index of policies
        nprimeindold = ninit*ones(statenum,1);    %(A*S*F*K*z*s*k*n,1) initialize the index of policies
    elseif (fcstit>1)
        %in this case, initialize using last GE iteration
        Vold = V;
        kprimeindold = kprimeind;
        nprimeindold = nprimeind;
        V_temp       = V;
    end     
    
    %pre compute fcst matrices based on thetaKold and thetapold
    thetaKoldmat = exp(repmat(thetaKold(:,:,:,1),[1 1 1 Knum]) + repmat(thetaKold(:,:,:,2),[1 1 1 Knum]).*permute(repmat(log(K0),[1 Anum Snum Fnum]),[2 3 4 1])); %(A,S,F,K) contains fcst K'(A,S,F,K)
    thetaKoldmat(thetaKoldmat>Kmax) = Kmax-0.001;
    thetaKoldmat(thetaKoldmat<Kmin) = Kmin+0.001;
    
    thetapoldmat = exp(repmat(thetapold(:,:,:,1),[1 1 1 Knum]) + repmat(thetapold(:,:,:,2),[1 1 1 Knum]).*permute(repmat(log(K0),[1 Anum Snum Fnum]),[2 3 4 1])); %(A,S,F,K) contains fcst p'(A,S,F,K)
    thetapoldmat(thetapoldmat>pub) = pub;
    thetapoldmat(thetapoldmat<plb) = plb;
    
    % compute the weight for interpolation
    Kprimewgtmat  = zeros(1, Aggnum);
    Kprimevalmat  = repmat(reshape(thetaKoldmat,[1 Aggnum]),[Knum 1]); %(K',A*S*F*K)
    K0mat         = repmat(K0,[1 Aggnum]); %(K',A*S*F*K)
    Kprimeindmat  = sum(Kprimevalmat>=K0mat,1);         
%    Kprimewgtmat  = (reshape(thetaKoldmat,[1 Anum*Knum*Snum])'-K0(Kprimeindmat))./(K0(Kprimeindmat+1)-K0(Kprimeindmat));
     % guard against off grid point values
     if sum(Kprimeindmat==Knum)>1 
        Kprimeindmat(Kprimeindmat==Knum) = Knum-1;
        Kprimewgtmat(Kprimeindmat==Knum) = 1.0;
     elseif sum(Kprimeindmat==0)>1
         Kprimeindmat(Kprimeindmat==0) = 1;
         Kprimewgtmat(Kprimeindmat==0) = 0;
     else
         Kprimewgtmat  = (reshape(thetaKoldmat,[1 Aggnum])'-K0(Kprimeindmat))./(K0(Kprimeindmat+1)-K0(Kprimeindmat));
     end        
                        
    %set up return matrix
        
        thetapmat  = reshape(repmat(thetapoldmat,[1 1 1 1 znum snum]),[Exognum 1]);     %(A,S,F,K,z,s) to (ASFKzs,1)
        parfor i=1:Exognum
                  ymat  = Amat(i).*zmat(i).*(kmat.^alpha);                    %(k',n',k,n)
                  ACmat = c_k*(imat~=0).*ymat;                                        %(k',n',k,n)
                  Emat  = ymat - imat - ACmat -hmat;                                  %(k',n',k,n)  
                  ECmat = FinCmat(i)*abs(Emat).*(Emat<0);                         %(k',n',k,n)
                  Rmat(:,:,:,:,i)= thetapmat(i)*(Emat - ECmat);                         %(k',n',k,n,ASFKsz)
        end   
        Rmat2 = reshape(permute(Rmat,[1 2 5 3 4]),[knum nnum statenum]);   %(k',n',A*S*F*K*z*s*k*n)   
        disp('Done with return matrix setup.')
      

    %do VF iteration 
    for vfit=1:maxvfit    
         % applying Howard (policy iteration)

            for howct=1:howardnum
            %vectorize: parallel
        
                  EVMAT4  = zeros(knum,nnum,exostatenum); %(k',n',A,S,F,K,z,s)
                  Kprimeindmat3 = reshape(repmat(reshape(Kprimeindmat,[Anum Snum Fnum Knum]),[1 1 1 1 znum snum]),[exostatenum 1]); %(A*S*F*K*z*s)
                  Kprimewgtmat3 = reshape(repmat(reshape(Kprimewgtmat,[Anum Snum Fnum Knum]),[1 1 1 1 znum snum]),[exostatenum 1]); %(A*S*F*K*z*s)
           
                  parfor exostatect = 1:exostatenum    %(A*S*F*K*z*s)
                         temp = squeeze((1-Kprimewgtmat3(exostatect))*Vold(:,:,:,Kprimeindmat3(exostatect),:,:,:,:)...
                                              + Kprimewgtmat3(exostatect)*Vold(:,:,:,Kprimeindmat3(exostatect)+1,:,:,:,:));  %(A',S',F',z',s',k',n')  

                          temp1 =  PiFmat1(exostatect,:)*reshape(PiSigmaAmat1(exostatect,:)*reshape(PiAmat1(exostatect,:)...
                                                                *reshape(temp,[Anum,Snum*Fnum*znum*snum*knum*nnum]),[Snum,Fnum*znum*snum*knum*nnum]),...
                                                                    [Fnum,znum*snum*knum*nnum]);  %(1,S',F',z',s',k',n') to %(S',F,z'*s'*k'*n') to (z',s',k',n')

                          EVMAT4(:,:,exostatect) = reshape(Pisigmazmat1(exostatect,:)*reshape(Pizmat1(exostatect,:)*reshape(temp1,[znum,snum*knum*nnum]),[snum,knum*nnum]),[knum nnum]); 
                                                    % (z',s',k',n') to (z',s'*k'*n') to (1,s',k'*n') to (k',n')
                 end
                 EVMAT_H2    = reshape(repmat(EVMAT4,[1 1 1 knum nnum]),[knum nnum statenum]);
                       
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%

                RHSMAT       = Rmat2 + beta*EVMAT_H2;                       %(k',n',A*S*F*K*z*s*k*n)
                k_n_prime_ind= sub2ind(size(RHSMAT),kprimeindold,nprimeindold,(1:statenum)');%(A*S*F*K*z*s*k*n,1)
                Vtemp        = RHSMAT(k_n_prime_ind); %(A*S*F*K*z*s*k*n,1)
                V            = reshape(Vtemp,[Anum Snum Fnum Knum znum snum knum nnum]);   %(A,S,F,K,z,s,k,n)        
                Vold = V;
            end
            
        % compute the expected firm value and find the optimal policy after
            EVMAT3  = zeros(knum,nnum,Aggnum,znum,snum); %(k',n',A,S,F,K,z,s)
             parfor Aggct = 1:Aggnum    %(A*S*F*K)
                    for zct = 1:znum 
                        for sct =1:snum   
                            temp = squeeze((1-Kprimewgtmat(Aggct))*Vold(:,:,:,Kprimeindmat(Aggct),:,:,:,:)...
                                              + Kprimewgtmat(Aggct)*Vold(:,:,:,Kprimeindmat(Aggct)+1,:,:,:,:));  %(A',S',F',z',s',k',n')  

                             temp1 =  PiFmat(Aggct,:)*reshape(PiSigmaAmat(Aggct,:)*reshape(PiAmat(Aggct,:)...
                                                                *reshape(temp,[Anum,Snum*Fnum*znum*snum*knum*nnum]),[Snum,Fnum*znum*snum*knum*nnum]),...
                                                                    [Fnum,znum*snum*knum*nnum]);  %(1,S',F',z',s',k',n') to %(S',F,z'*s'*k'*n') to (z',s',k',n')

                              EVMAT3(:,:,Aggct,zct,sct) = reshape(Pisigma(sct,:)*reshape(Piz(zct,:,sct)*reshape(temp1,[znum,snum*knum*nnum]),[snum,knum*nnum]),[knum nnum]); 
%                                                     (z',s',k',n') to (z',s'*k'*n') to (1,s',k'*n') to    %(k',n')
                                    
                        end
                    end
             end
             EVMAT_H2    = reshape(repmat(EVMAT3,[1 1 1 1 1 knum nnum]),[knum nnum statenum]);
          
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
          RHSMAT       = Rmat2 + beta*EVMAT_H2;                       %(k',n',A*S*F*K*z*s*k*n)
        
        [Vdum,nprimeinddum] = max(max(RHSMAT,[],1),[],2);
        [~,   kprimeinddum] = max(max(RHSMAT,[],2),[],1);
        V         = reshape(squeeze(Vdum),Anum,Snum,Fnum,Knum,znum,snum,knum,nnum);
        kprimeind = squeeze(kprimeinddum);
        nprimeind = squeeze(nprimeinddum);
        kprime    = k0(kprimeind);
        nprime    = n0(nprimeind); 
                
        %evaluate error of VF and policies
        vferr = max(abs(V(:)-Vold(:)));
        polkerr = max(abs(kprime(:)-kprimeold(:)));
        polnerr = max(abs(nprime(:)-nprimeold(:)));
     
        %display diagnostics periodically
        if (mod(vfit,10)==1)
            disp(['On VF iter ' num2str(vfit) ' |V-Vold| = ' num2str(vferr) ' |k - kold| = ' num2str(polkerr) ' |n - nold| = ' num2str(polnerr)])   
        end
        
        %exit VF loop if converged
        if (vferr<maxvferr) 
            break
        end
        
        %if haven't converged, update and move on
        kprimeold    = kprime;
        kprimeindold = kprimeind;
        nprimeold    = nprime;
        nprimeindold = nprimeind;
        Vold         = V;
    end
    
    disp('Done with VFI.')
    disp(' ')
    toc;

    kprime = reshape(kprime, Anum,Snum,Fnum,Knum,znum,snum,knum,nnum);
    nprime = reshape(nprime, Anum,Snum,Fnum,Knum,znum,snum,knum,nnum);
    I = kprime - (1-delta)*permute(repmat(k0,[1 Anum,Snum,Fnum,Knum,znum,snum,nnum]),[ 2 3 4 5 6 7 1 8]);
    H = nprime - (1+rf)*kappa*permute(repmat(n0',[1 Anum,Snum,Fnum,Knum,znum,snum,knum]),[ 2 3 4 5 6 7 8 1 ]);
  %%   
    %simulate the model, with market clearing loop
    
    %initialize macro capital/price & micro dist
    Ksim     = zeros(Ttot,1);
    Ksim(1)  = K0(Kinit);
    psim     = zeros(Ttot,1);
    Csim     = zeros(Ttot,1);
    Isim     = zeros(Ttot,1);
    ACsim    = zeros(Ttot,1);
    Ysim     = zeros(Ttot,1);
    Kfcstsim = zeros(Ttot,1);
    pfcstsim = zeros(Ttot,1);
    perrsim  = zeros(Ttot,1);
    Nsim     = zeros(Ttot,1);
    Lsim     = zeros(Ttot,1);
%     load distold.mat
    distold          = zeros(znum,snum,knum,nnum);
    distzskn         = zeros(znum,snum,knum,nnum,Ttot);
    distold(zinit,:,:,:) = 1;
    distold          = distold/sum(distold(:));   
    distzskn(:,:,:,:,1) = distold;
    
    %actually conduct the simulation
    disp('Simulating the model.')

    for t=1:(Ttot-1)
        %extract macro prod
        Act       = Asimpos(t);
        Aval      = A0(Act);
        Sct       = sigmaAsimpos(t);
        sigmaAval = sigmaA(Sct);
        Fct       = Fsimpos(t);
        FSval     = FinC(Fct);     % financial cost is function of sigma and financial shock
        %extract macro capital and forecasts
        Kval      = Ksim(t);
        Kval      = min(max(Kval,Kmin),Kmax);
        Kprimeval = exp(thetaKold(Act,Sct,Fct,1) + thetaKold(Act,Sct,Fct,2)*log(Kval));
        Kprimeind = sum(Kprimeval>=K0);
        
        % guard against off grid point values
    	if (Kprimeind==Knum) 
        	Kprimeind = Knum-1;
	        Kprimewgt = 1.0;
        elseif (Kprimeind==0)
        	Kprimeind = 1;
	        Kprimewgt = 0.0;
        else
           Kprimewgt  = (Kprimeval-K0(Kprimeind))/(K0(Kprimeind+1)-K0(Kprimeind));
    	end 
        Kfcstsim(t+1) = Kprimeval;
        
        pfcstval    = exp(thetapold(Act,Sct,Fct,1) + thetapold(Act,Sct,Fct,2)*log(Kval));
        pfcstsim(t) = pfcstval;
        
        %initialize price iteration
       	%set up the interpolation of the excess demand function e(p) = 1/p - C(p)
 
        %note that EVmat doesn't depend on p, so it can be pre-computed now, as EVMAT_H
        % compute the expected firm value  

%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
          %1. Follow step by step computing the conditional expectation
          %E_t(V(A',S',F',K',z',s',k',n')
            disttemp   = repmat(distzskn(:,:,:,:,t),[1 1 1 1 pnum]);%(z,s,k,n,p)
            
            EVMAT    = zeros(knum,nnum,znum,snum);
            for zct = 1:znum 
                 for sct =1:snum   
                        temp = (1-Kprimewgt)*squeeze(V(:,:,:,Kprimeind,:,:,:,:))...
                                        + Kprimewgt*squeeze(V(:,:,:,Kprimeind+1,:,:,:,:));             %(A',S',F',z',s',k',n')                                
                        temp1 = PiF(Fct,:)*reshape(PiSigma(Sct,:)*reshape(PiA(Act,:,Sct)*reshape(temp,...
                                [Anum,Snum*Fnum*znum*snum*knum*nnum]),...
                                [Snum,Fnum*znum*snum*knum*nnum]),...
                                [Fnum,znum*snum*knum*nnum]);  
                        EVMAT(:,:,zct,sct) = reshape(Pisigma(sct,:)*reshape(Piz(zct,:,sct)*reshape(temp1,[znum,snum*knum*nnum]),[snum,knum*nnum]),[knum nnum]); 
                                    
                 end
            end
        
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
        
          EVMAT_H = repmat(EVMAT,[1 1 1 1 knum nnum pnum]);               %(k',n',z,s,k,n,p)
          yval    = Aval*k1_z_alpha;                             %(k',n',z,s,k,n,p)
          ACval    = c_k*(ival~=0).*yval;                                 %(k',n',z,s,k,n,p)
          Eval     = (yval - ival - ACval -hval);                         %(k',n',z,s,k,n,p)
          ECval    = FSval*abs(Eval).*(Eval<0);                           %(k',n',z,s,k,n,p)
          Rmatp    = pmat.*(Eval-ECval);                                 %(k',n',z,s,k,n,p)

          RHSMAT2 = Rmatp + beta*EVMAT_H;
          [~,kprimeinddum] = max(max(RHSMAT2,[],2),[],1);
          [~,nprimeinddum] = max(max(RHSMAT2,[],1),[],2);
          kprimeindp = squeeze(kprimeinddum);
          nprimeindp = squeeze(nprimeinddum);
         
          kprimep    = k0(kprimeindp);
          nprimep    = n0(nprimeindp);
                 
           yvalp = squeeze(yval(1,1,:,:,:,:,:));
           ivalptemp = (kprimep - (1-delta)*squeeze(k1(1,1,:,:,:,:,:)));
           ivalptemp(abs(ivalptemp)<=ZeroBoundary)=0; 
           ivalp = ivalptemp;
           acvalp = c_k*(ivalp~=0).*yvalp;
           hvalp  = nprimep-(1+rf)*kappa*squeeze(n1(1,1,:,:,:,:,:));
           evalp  = (yvalp-ivalp-acvalp-hvalp);
           ecvalp = FSval*abs(evalp).*(evalp<0);

           Yp0        = squeeze(sum(sum(sum(sum(yvalp.*disttemp)))));
           Ip0        = squeeze(sum(sum(sum(sum(ivalp.*disttemp)))));
           ACp0       = squeeze(sum(sum(sum(sum(acvalp.*disttemp)))));
           Kbarprimep0= squeeze(sum(sum(sum(sum(kprimep.*disttemp)))));
           Nbarprimep0= squeeze(sum(sum(sum(sum(nprimep.*disttemp)))));
           ECp0       = squeeze(sum(sum(sum(sum(ecvalp.*disttemp)))));
           polmatp_interp_k = kprimeindp; %(z,s,k,n,p)
           polmatp_interp_n = nprimeindp; %(z,s,k,n,p)
     
           Cvalp = Yp0-Ip0-ACp0-ECp0;
           Cp0   = Cvalp;
           ep0   = 1./p0'-Cvalp;
        
  %set up the boundaries of the initial price bracket [pvala, pvalb] and an
    %interior third point pvalc
    pvala = plb;
    pvalb = pub;
    pvalc = pvala + (0.67)*(pvalb-pvala);

    %iterate over the market-clearing value of p. Each iteration linearly
    %interpolates the precomputed consumption function Cp0 over the price grid
    %p0 and updates a Brent-style bracket on the excess demand 1/p - C(p)
    for piter=1:maxpit

        %brent setup: evaluate at both bracket ends and the interior point
        if (piter==1); pval = pvala;end
        if (piter==2); pval = pvalb;end
        if (piter==3); pval = pvalc;end

        %brent restart: at pcutoff, re-bracket around the forecast price
        %pfcstsim(t) with a window of +/- 4*pwindow
        if (piter==pcutoff)
            pvala = pfcstsim(t)-4.0*pwindow;
            pval = pvala;
        elseif (piter==pcutoff+1)
            pvalb =  pfcstsim(t)+4.0*pwindow;
            pval = pvalb;
        elseif (piter==pcutoff+2)
            pvalc = pvala + (0.67) * (pvalb-pvala);
            pval = pvalc;
        end

        %if not restarting or initializing
        if ((piter>3 && piter<pcutoff) || (piter>pcutoff+2))
            %first, try inverse quadratic interpolation of the excess demand function
            pval = ( pvala * fb * fc ) / ( (fa - fb) * (fa - fc) ) ...
                + ( pvalb * fa * fc ) / ( (fb - fa) * (fb - fc ) ) ...
                + ( pvalc * fa * fb ) / ( (fc - fa) * (fc - fb ) );
            %take a bisection step instead if the interpolant is not finite
            %(two equal f values give 0/0), lies outside the bracket, or is
            %within 1/9 of the bracket width of either end. Without the
            %isfinite test a NaN fails every comparison below and would be
            %clamped to the bottom of the price grid instead of bisected.
            if ~isfinite(pval) || (min(abs(pvala - pval),abs(pvalb-pval))<abs((pvalb-pvala)/9)) || (pval<pvala) || (pval>pvalb)
                pval = (pvala + pvalb) /2;
            end
        end

        %clamp to the price grid and locate the grid interval containing pval
        pval = min(max(p0(1),pval),p0(pnum));
        pind = sum(pval>=p0);

        %guard against off grid point values
        if (pind==pnum)
            pind = pnum-1;
            pwgt = 1.0;
        elseif (pind==0)
            pind = 1;
            pwgt = 0.0;
        else
            pwgt = (pval - p0(pind)) / (p0(pind + 1) - p0(pind));
        end

        %actually perform linear interpolation of the consumption function
        CvalpNew = Cp0(pind)*(1-pwgt) + Cp0(pind+1)*pwgt;

        %are you initializing the brent?
        if (piter==1); fa = (1/pval) - CvalpNew;end
        if (piter==2); fb = (1/pval) - CvalpNew;end
        if (piter==3); fc = (1/pval) - CvalpNew;end

        %are you restarting the brent?
        if (piter==pcutoff);   fa = (1/pval) - CvalpNew; end
        if (piter==pcutoff+1); fb = (1/pval) - CvalpNew; end
        if (piter==pcutoff+2); fc = (1/pval) - CvalpNew; end

        %what is the error given by this implied consumption?
        perror = (1/pval) - CvalpNew;

        %if not restarting or initializing, shrink the bracket: a negative
        %error replaces the upper end, a non-negative one the lower end, and
        %the replaced end becomes the third IQI point
        if ((piter>3 && piter<pcutoff)||(piter>pcutoff+2))
            if (perror<0)
                pvalc = pvalb; fc = fb;
                pvalb = pval; fb = perror;
                %pval a doesn't change
            elseif (perror>=0)
                pvalc = pvala; fc = fa;
                pvala = pval; fa = perror;
                %pval b doesn't change
            end
        end

        %exit criterion for market-clearing (log deviation of p*C from 1)
        perror = log(pval*CvalpNew);
        if (abs(perror)<perrortol && piter>2)
            break
        end

    end %piter
    
    %record this period's outcome: the last price tried by the clearing loop
    %and every aggregate as a linear blend of its values at price nodes pind
    %(weight 1-pwgt) and pind+1 (weight pwgt)
    psim(t)     = pval;
    Csim(t)     = CvalpNew; %interpolated already inside the clearing loop
    Ysim(t)     = (1.0-pwgt)*Yp0(pind) + pwgt*Yp0(pind+1);                 %output Y
    Isim(t)     = (1.0-pwgt)*Ip0(pind) + pwgt*Ip0(pind+1);                 %investment I
    ACsim(t)    = (1.0-pwgt)*ACp0(pind) + pwgt*ACp0(pind+1);               %capital adjustment cost
    Ksim(t+1)   = (1.0-pwgt)*Kbarprimep0(pind) + pwgt*Kbarprimep0(pind+1); %next-period K'
    Nsim(t+1)   = (1.0-pwgt)*Nbarprimep0(pind) + pwgt*Nbarprimep0(pind+1); %next-period labor input
    %Hsim(t)    = (1.0-pwgt)*Hp0(pind) + pwgt*Hp0(pind+1);                 %hiring H (disabled)

    %the weights pind/pwgt are reused below to push the distribution forward

    %progress report every 50 periods
    if mod(t,50)==1
        disp(['t = ',num2str(t), ' iter ',num2str(piter),' p ',num2str(pval), ...
              ' err ',num2str(perror), ' A ',num2str(A0(Act)), ...
              ' Kprime ' ,num2str(Ksim(t+1)),' Nprime ' ,num2str(Nsim(t+1))])
    end
    
   %push this period's distribution forward: the mass at (z,s,k,n) moves to the
   %policy grid points (k',n') chosen at price nodes pind and pind+1 (weights
   %1-pwgt and pwgt), spread over (z',s') by the exogenous transition probabilities
   for zct=1:znum
       for sct=1:snum
        %(z,s) -> (z',s') transition probabilities (znum x snum). They do not
        %depend on (k,n), so build them once here rather than inside the k,n loops
        Ptrans_zs_tmp = repmat(Piz(zct,:,sct)',1,snum).*repmat(Pisigma(sct,:),znum,1);
        for kct=1:knum
            for nct=1:nnum
                if (distzskn(zct,sct,kct,nct,t)>disttol)
                    %mass leaving this state, spread over (z',s')
                    distflow_tmp = Ptrans_zs_tmp*distzskn(zct,sct,kct,nct,t);

                    %based on the latest price, what is the policy here at pind?
                    polstar_k_0 = polmatp_interp_k(zct,sct,kct,nct,pind);
                    polstar_n_0 = polmatp_interp_n(zct,sct,kct,nct,pind);

                    %insert distributional weight in appropriate slots next period
                    distzskn(:,:,polstar_k_0,polstar_n_0,t+1) = distzskn(:,:,polstar_k_0,polstar_n_0,t+1) + ...
                        distflow_tmp * (1.0-pwgt);

                    %based on the latest price, what is the policy here at pind+1?
                    polstar_k_1 = polmatp_interp_k(zct,sct,kct,nct,pind+1);
                    polstar_n_1 = polmatp_interp_n(zct,sct,kct,nct,pind+1);

                    %insert distributional weight in appropriate slots next period
                    distzskn(:,:,polstar_k_1,polstar_n_1,t+1) = distzskn(:,:,polstar_k_1,polstar_n_1,t+1) + ...
                        distflow_tmp * pwgt;
                end
            end %nct
        end  %kct
       end %sct
   end  %zct
   clear Ptrans_zs_tmp distflow_tmp

     %renormalize next period's distribution to sum to one (mass in states at
     %or below disttol was not pushed forward above)
     distzskn(:,:,:,:,t+1) = distzskn(:,:,:,:,t+1)./sum(sum(sum(sum(distzskn(:,:,:,:,t+1)))));
        
    end
    
    if (checkbounds==1)

        %share of the distribution on the top/bottom k and n grid points in each
        %simulated period. The other state dims are summed explicitly: sum()
        %without a dim acts on the first non-singleton dim, so with a singleton
        %dim (e.g. snum==1) the nested sums would also add up over periods.
        TOPksim = squeeze(sum(sum(sum(distzskn(:,:,knum,:,:),1),2),4));
        BOTksim = squeeze(sum(sum(sum(distzskn(:,:,1,:,:),1),2),4));
        TOPnsim = squeeze(sum(sum(sum(distzskn(:,:,:,nnum,:),1),2),3));
        BOTnsim = squeeze(sum(sum(sum(distzskn(:,:,:,1,:),1),2),3));

        %largest edge mass over periods Terg+1..Ttot-1
        TOPkmax = max(TOPksim((Terg+1):(Ttot-1)));
        BOTkmax = max(BOTksim((Terg+1):(Ttot-1)));
        TOPnmax = max(TOPnsim((Terg+1):(Ttot-1)));
        BOTnmax = max(BOTnsim((Terg+1):(Ttot-1)));

        disp(' ')
        disp('Checking for bounds of state space in simulation.');
        disp(['Top k = ' num2str(TOPkmax),', Bottom k = ' num2str(BOTkmax), ...
              ', Top n = ' num2str(TOPnmax),', Bottom n = ' num2str(BOTnmax)])
        %flag bunching at the grid endpoints above the configured tolerance
        if (max([TOPkmax BOTkmax TOPnmax BOTnmax])>maxbunch)
            disp(['Warning: mass at the grid endpoints exceeds maxbunch = ' num2str(maxbunch)])
        end

    end
    

%%
    disp('Done with simulation.')
    toc;

    %update the forecast rules, test for convergence
    %regression sample: K(t) -> K(t+1) and K(t) -> p(t), with the aggregate
    %state (A, sigmaA, F) taken at date t
    Kest = Ksim((Terg):(Ttot-2));
    pest = psim((Terg):(Ttot-2));
    Kprimeest       = Ksim((Terg+1):(Ttot-1));
    Asimposest      = Asimpos((Terg):(Ttot-2));
    sigmaAsimposest = sigmaAsimpos((Terg):(Ttot-2));
    Fsimposest      = Fsimpos((Terg):(Ttot-2));

    Aobsct = zeros(Anum,Snum,Fnum);  %number of observations per aggregate state
    KPred  = zeros(length(Kprimeest),1);
    pPred  = zeros(length(pest),1);

    for Act=1:Anum
        for Sct = 1:Snum
            for Fct = 1:Fnum
               samp = (Asimposest==Act & sigmaAsimposest==Sct & Fsimposest==Fct);
               Aobsct(Act,Sct,Fct) = sum(samp);

               if ~any(samp)
                   %state never visited: X'X is singular and OLS would return
                   %NaN coefficients, so keep the current rule for this state
                   thetaK(Act,Sct,Fct,:) = thetaKold(Act,Sct,Fct,:);
                   thetap(Act,Sct,Fct,:) = thetapold(Act,Sct,Fct,:);
                   continue
               end

               %regressors shared by both rules: constant and log K
               X = [ones(sum(samp),1) log(Kest(samp))];

               %capital rule: log K' = thetaK(1) + thetaK(2)*log K
               Y = log(Kprimeest(samp));
               betaOLS = (X'*X)\(X'*Y);
               thetaK(Act,Sct,Fct,:) = betaOLS';
               KPred(samp) = betaOLS'*X';

               %price rule: log p = thetap(1) + thetap(2)*log K
               Y = log(pest(samp));
               betaOLS = (X'*X)\(X'*Y);
               thetap(Act,Sct,Fct,:) = betaOLS';
               pPred(samp) = betaOLS'*X';
            end
        end
    end
    
    %pooled R^2 of the estimated rules: regress the fitted values on the
    %realised ones. KPred and pPred are fitted LOG values (the rules are
    %log-linear), so the realised series must enter in logs as well.
    [~, ~, ~, ~, e] = regress(KPred,[ones(length(KPred),1) log(Kprimeest)]); R2K(fcstit) = e(1);
    [~, ~, ~, ~, e] = regress(pPred,[ones(length(pPred),1) log(pest)]);      R2p(fcstit) = e(1);

    %largest coefficient change between the current rules and the new estimates
    thetaKerr = max(abs(thetaK(:)-thetaKold(:)));
    thetaperr = max(abs(thetap(:)-thetapold(:)));
    fcsterr = max([thetaKerr,thetaperr]);

    disp (' ')
    disp('Fcst rules & new estimates')
    disp(['Max abs err ' num2str(fcsterr)])
    disp(' ')
    disp('Capital: Old, New')
    disp(num2str([thetaKold thetaK]))
    disp('Capital: Err')
    disp(num2str([thetaK-thetaKold]))
    disp(' ')
    disp('Price: Old, New')
    disp(num2str([thetapold thetap]))
    disp('Price: Err')
    disp(num2str([thetap-thetapold]))

    disp('R2K R2p')
    disp(num2str([R2K(fcstit),R2p(fcstit)]))
   
    
%     if (fcsterr<maxfcsterr) 
%        break 
%     end
    
    %per-aggregate-state accuracy of the forecasts pfcstsim/Kfcstsim over
    %t = Terg+1..Ttot-1, in logs: RMSE, and R^2 = 1 - SSE/SST around state means
    pmean = zeros(Anum,Snum,Fnum);
    Kmean = zeros(Anum,Snum,Fnum);
    pSSE  = zeros(Anum,Snum,Fnum);
    KSSE  = zeros(Anum,Snum,Fnum);
    pSST  = zeros(Anum,Snum,Fnum);
    KSST  = zeros(Anum,Snum,Fnum);

    for Act=1:Anum
        for Sct=1:Snum
            for Fct=1:Fnum
                %pass 1: count the periods spent in this state and accumulate the
                %log levels and squared log forecast errors of price and capital
                perct = 0;
                for t=(Terg+1):(Ttot-1)
                    if Asimpos(t)==Act && sigmaAsimpos(t)==Sct && Fsimpos(t)==Fct
                        perct = perct+1;
                        pmean(Act,Sct,Fct) = pmean(Act,Sct,Fct) + log(psim(t));
                        Kmean(Act,Sct,Fct) = Kmean(Act,Sct,Fct) + log(Ksim(t));
                        pSSE(Act,Sct,Fct)  = pSSE(Act,Sct,Fct) + ( log(psim(t)) - log(pfcstsim(t)) )^2;
                        KSSE(Act,Sct,Fct)  = KSSE(Act,Sct,Fct) + ( log(Ksim(t)) - log(Kfcstsim(t)) )^2;
                    end
                end

                %turn the sums into means, and the SSEs into RMSEs
                pmean(Act,Sct,Fct) = pmean(Act,Sct,Fct)/perct;
                Kmean(Act,Sct,Fct) = Kmean(Act,Sct,Fct)/perct;
                pRMSEstore(Act,Sct,Fct,fcstit) = sqrt( pSSE(Act,Sct,Fct) / perct);
                KRMSEstore(Act,Sct,Fct,fcstit) = sqrt( KSSE(Act,Sct,Fct) / perct);

                %pass 2: total sum of squares around the state means
                for t=(Terg+1):(Ttot-1)
                    if Asimpos(t)==Act && sigmaAsimpos(t)==Sct && Fsimpos(t)==Fct
                        pSST(Act,Sct,Fct) = pSST(Act,Sct,Fct) + (log(psim(t)) - pmean(Act,Sct,Fct))^2;
                        KSST(Act,Sct,Fct) = KSST(Act,Sct,Fct) + (log(Ksim(t)) - Kmean(Act,Sct,Fct))^2;
                    end
                end

                %R^2 of the forecasts in this state
                pR2store(Act,Sct,Fct,fcstit) = 1 - pSSE(Act,Sct,Fct)/pSST(Act,Sct,Fct);
                KR2store(Act,Sct,Fct,fcstit) = 1 - KSSE(Act,Sct,Fct)/KSST(Act,Sct,Fct);
            end %Fct
        end %Sct
    end %Act

    
      %summarize the RMSE and R^2 metrics as one worst-case number each: the
      %largest change across aggregate states since the previous fcst
      %iteration, or on the first iteration the absolute RMSE and |1 - R^2|
        if (fcstit~=1)
            KRMSEchange = max(abs(reshape(KRMSEstore(:,:,:,fcstit)-KRMSEstore(:,:,:,fcstit-1),[],1)));
            pRMSEchange = max(abs(reshape(pRMSEstore(:,:,:,fcstit)-pRMSEstore(:,:,:,fcstit-1),[],1)));
            KR2change   = max(abs(reshape(KR2store(:,:,:,fcstit)-KR2store(:,:,:,fcstit-1),[],1)));
            pR2change   = max(abs(reshape(pR2store(:,:,:,fcstit)-pR2store(:,:,:,fcstit-1),[],1)));
        else
            KRMSEchange = max(abs(reshape(KRMSEstore(:,:,:,fcstit),[],1)));
            pRMSEchange = max(abs(reshape(pRMSEstore(:,:,:,fcstit),[],1)));
            KR2change   = max(abs(reshape(1 - KR2store(:,:,:,fcstit),[],1)));
            pR2change   = max(abs(reshape(1 - pR2store(:,:,:,fcstit),[],1)));
        end

        %keep the history of these summaries across fcst iterations
        KRMSEchangestore(fcstit) = KRMSEchange;
        pRMSEchangestore(fcstit) = pRMSEchange;
        KR2changestore(fcstit)   = KR2change;
        pR2changestore(fcstit)   = pR2change;
  
    
    %%%Compute the Den Haan fcst series

        %dynamic forecast: start from the simulated logs at Terg and iterate the
        %estimated rules forward on their OWN lagged capital forecast (never
        %re-anchored to Ksim), then compare the paths with the simulation
        KDenHaanfcst = zeros(Ttot,1);
        pDenHaanfcst = zeros(Ttot,1);
        KDenHaanfcst(Terg) = log(Ksim(Terg));
        pDenHaanfcst(Terg) = log(psim(Terg));

        %running max and running sum of the absolute log forecast errors
        KmaxDenHaanstat(fcstit) = 0;
        pmaxDenHaanstat(fcstit) = 0;
        KavgDenHaanstat(fcstit) = 0;
        pavgDenHaanstat(fcstit) = 0;

        for t= (Terg+1):(Ttot-1)
            %capital: log K(t) from log K(t-1) under the rule of the aggregate
            %state at t-1, when this forecast was made
            Act = Asimpos(t-1); Sct = sigmaAsimpos(t-1); Fct = Fsimpos(t-1);
            KDenHaanfcst(t) = thetaK(Act,Sct,Fct,1) + thetaK(Act,Sct,Fct,2)*KDenHaanfcst(t-1);

            %price: log p(t) from the forecast log K(t) under the rule of the
            %aggregate state at t
            Act = Asimpos(t); Sct = sigmaAsimpos(t); Fct = Fsimpos(t);
            pDenHaanfcst(t) = thetap(Act,Sct,Fct,1) + thetap(Act,Sct,Fct,2)*KDenHaanfcst(t);

            %fold this period's errors into the max and sum statistics
            KmaxDenHaanstat(fcstit) = max(KmaxDenHaanstat(fcstit), abs(KDenHaanfcst(t) - log(Ksim(t))));
            KavgDenHaanstat(fcstit) = KavgDenHaanstat(fcstit) + abs(KDenHaanfcst(t) - log(Ksim(t)));
            pmaxDenHaanstat(fcstit) = max(pmaxDenHaanstat(fcstit), abs(pDenHaanfcst(t) - log(psim(t))));
            pavgDenHaanstat(fcstit) = pavgDenHaanstat(fcstit) + abs(pDenHaanfcst(t) - log(psim(t)));
        end  %t

        %sums -> means over the Ttot-Terg-1 compared periods
        KavgDenHaanstat(fcstit) = KavgDenHaanstat(fcstit) / (Ttot-Terg-1);
        pavgDenHaanstat(fcstit) = pavgDenHaanstat(fcstit) / (Ttot-Terg-1);

        %change since the previous fcst iteration (the level itself on the first)
        if (fcstit>1)
            KmaxDenHaanchange = abs(KmaxDenHaanstat(fcstit)-KmaxDenHaanstat(fcstit-1));
            pmaxDenHaanchange = abs(pmaxDenHaanstat(fcstit)-pmaxDenHaanstat(fcstit-1));
            KavgDenHaanchange = abs(KavgDenHaanstat(fcstit)-KavgDenHaanstat(fcstit-1));
            pavgDenHaanchange = abs(pavgDenHaanstat(fcstit)-pavgDenHaanstat(fcstit-1));
        elseif (fcstit==1)
            KmaxDenHaanchange = KmaxDenHaanstat(fcstit);
            pmaxDenHaanchange = pmaxDenHaanstat(fcstit);
            KavgDenHaanchange = KavgDenHaanstat(fcstit);
            pavgDenHaanchange = pavgDenHaanstat(fcstit);
        end

        %history of the changes across fcst iterations
        KmaxDenHaanchangestore(fcstit) = KmaxDenHaanchange;
        pmaxDenHaanchangestore(fcstit) = pmaxDenHaanchange;
        KavgDenHaanchangestore(fcstit) = KavgDenHaanchange;
        pavgDenHaanchangestore(fcstit) = pavgDenHaanchange;

        
    %exit criterion for the fcstit loop: GEerrorswitch selects which metric
    %must fall below its tolerance for both the capital and the price rule
    exitflag = 0;
    switch GEerrorswitch
        case 1  %coefficient convergence
            if thetaKerr<fcsterrortol && thetaperr<fcsterrortol
                exitflag = 1;
            end
        case 2  %RMSE convergence
            if KRMSEchange<RMSEchangetol && pRMSEchange<RMSEchangetol
                exitflag = 1;
            end
        case 3  %R2 convergence
            if KR2change<R2changetol && pR2change<R2changetol
                exitflag = 1;
            end
        case 4  %max Den Haan stat convergence
            if KmaxDenHaanchange<maxDenHaanchangetol && pmaxDenHaanchange<maxDenHaanchangetol
                exitflag = 1;
            end
        case 5  %avg Den Haan stat convergence
            if KavgDenHaanchange<avgDenHaanchangetol && pavgDenHaanchange<avgDenHaanchangetol
                exitflag = 1;
            end
    end

    %report every metric, whichever one drives the exit
    disp(['fcst error = ',num2str(max(thetaKerr,thetaperr)), ...
          ' RMSE ',num2str(max(KRMSEchange,pRMSEchange)), ...
          ' R2 change ',num2str(max(KR2change,pR2change)), ...
          ' DH max ',num2str(max(KmaxDenHaanchange,pmaxDenHaanchange)), ...
          ' DH avg ',num2str(max(KavgDenHaanchange,pavgDenHaanchange))])

    if (exitflag==1)
        break
    end
  
    
    %damped update: move the rules 100*fcstdamp % of the way from the current
    %coefficients toward the new OLS estimates
    thetaKold = thetaKold + fcstdamp*(thetaK-thetaKold);
    thetapold = thetapold + fcstdamp*(thetap-thetapold);

    %checkpoint the updated rules on their own
    save('Theta.mat','thetaKold','thetapold');

    %drop intermediate arrays (names sorted, duplicates removed) before
    %checkpointing the whole workspace
    clear('ACmat','ACval','Azmat_ly','ECmat','ECval','EVMAT','EVMAT_H','EVMAT_H2', ...
          'Emat','Eval','Ivalp','KSSE','KSST','RHSMAT','RHSMAT2', ...
          'Rmat','Rmat2','Rmatp','V_temp','Vold','Vtemp', ...
          'acvalp','distemp','distold','disttemp','ecvalp','evalp','hvalp', ...
          'ivalp','ivalptemp','ivaltemp','k_n_prime_ind','kmat_ly', ...
          'kprimeinddum','kprimeindold','kprimeindp','kprimemat','kprimep', ...
          'lval','lvalp','nprimeinddum','nprimeindold','nprimeindp','nprimep', ...
          'polmatp_interp_k','polmatp_interp_n','samp','temp','tempz','tempzs', ...
          'y1','yval','yvalp','z1');
    %kept (not cleared): nprimemat nmat lmat kmat imattemp imat hmat zmat
    %FinSigmamat Amat PiFmat PiSigmamat PiAmat nprimeind nprimeold

    save ('ks_het_firms_temp.mat');

    
end %of fcst loop


%only claim success when the fcst loop exited through its convergence test
if (exist('exitflag','var') && exitflag==1)
    disp('Finished fcst rule iteration. The model is solved!')
else
    warning('Fcst rule iteration ended without meeting the convergence criterion.')
end
toc;
disp(' ')

if (diary_record)
    diary OFF
end

%drop intermediate arrays, then save the solved model. The results file is
%written whether or not the diary log is being recorded.
clear Vold V_temp Rmat2 EVMAT_H2 Rmatp RHSMAT2 nprimeinddum lval
clear kprimeinddum ivaltemp ival hval EVMAT_H Eval ECval ACval ACmat acvalp
clear distold temp yvalp lvalp ivalptemp distemp kprimemat kmat_ly kprimeind k_n_prime_ind
clear RHSMAT yval z1 y1 tempz KSSE KSST polmatp_interp_k polmatp_interp_n
clear nprimep nprimeindp Ivalp kprimep kprimeindp ivalp hvalp evalp ecvalp
% clear tempzs nprimemat nmat lmat kmat imattemp imat hmat EVMAT Emat ECmat Vtemp nprimeindold nprimeind
% clear kprimeold kprimeindold zmat FinSigmamat Amat Azmat_ly disttemp samp PiFmat PiSigmamat PiAmat
save ('ks_het_firms_baseline.mat');

    
